# -*- coding:utf-8 -*-

from tensorflow.python import pywrap_tensorflow
from utils.io_utils import convert_abspath


def watch_model(model_path):
    """
    Print the name and shape of every variable stored in a TensorFlow checkpoint.

    :param model_path: checkpoint path, passed unchanged to
        ``pywrap_tensorflow.NewCheckpointReader``.
    :return: the ``{variable_name: shape}`` mapping reported by the checkpoint
        reader, so callers can also inspect it programmatically.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    print("have {} tensor".format(len(var_to_shape_map)))
    for key, shape in var_to_shape_map.items():
        # Use the shape metadata directly instead of reader.get_tensor(key).shape,
        # which would load the whole tensor's values into memory just to print
        # its shape. tuple() keeps the printed form identical to a numpy shape.
        print("tensor_name:{}, shape:{}".format(key, tuple(shape)))
    return var_to_shape_map


if __name__ == '__main__':
    import argparse

    # Checkpoints of previously trained models, as paths accepted by convert_abspath
    # (presumably relative to the project root — confirm against utils.io_utils).
    model_paths = {
        'single_layer': 'data/model/dnn/single_layer/10101422/model',
        'lenet5': 'data/model/cnn/lenet5/10151326/model',
        'drnn': 'data/model/rnn/drnn/10151547/model',
    }

    parser = argparse.ArgumentParser(
        description='Print the variables stored in a TensorFlow checkpoint.')
    # Optional so that running the script without arguments keeps the previous
    # behaviour of inspecting the drnn checkpoint.
    parser.add_argument(
        'model', nargs='?', default='drnn',
        help='one of {} or a checkpoint path (default: drnn)'.format(
            ', '.join(sorted(model_paths))))
    args = parser.parse_args()

    # Known keys map to their stored path; anything else is treated as a path itself.
    path = convert_abspath(model_paths.get(args.model, args.model))
    watch_model(path)
